{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2018-12-20 14:45:32,614 INFO sqlalchemy.engine.base.Engine SELECT CAST('test plain returns' AS VARCHAR(60)) AS anon_1\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[2018-12-20 14:45:32] INFO    | sqlalchemy.engine.base.Engine: SELECT CAST('test plain returns' AS VARCHAR(60)) AS anon_1\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2018-12-20 14:45:32,618 INFO sqlalchemy.engine.base.Engine ()\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[2018-12-20 14:45:32] INFO    | sqlalchemy.engine.base.Engine: ()\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2018-12-20 14:45:32,624 INFO sqlalchemy.engine.base.Engine SELECT CAST('test unicode returns' AS VARCHAR(60)) AS anon_1\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[2018-12-20 14:45:32] INFO    | sqlalchemy.engine.base.Engine: SELECT CAST('test unicode returns' AS VARCHAR(60)) AS anon_1\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2018-12-20 14:45:32,626 INFO sqlalchemy.engine.base.Engine ()\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[2018-12-20 14:45:32] INFO    | sqlalchemy.engine.base.Engine: ()\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2018-12-20 14:45:32,631 INFO sqlalchemy.engine.base.Engine PRAGMA table_info(\"posts\")\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[2018-12-20 14:45:32] INFO    | sqlalchemy.engine.base.Engine: PRAGMA table_info(\"posts\")\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2018-12-20 14:45:32,635 INFO sqlalchemy.engine.base.Engine ()\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[2018-12-20 14:45:32] INFO    | sqlalchemy.engine.base.Engine: ()\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2018-12-20 14:45:32,639 INFO sqlalchemy.engine.base.Engine PRAGMA table_info(\"replys\")\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[2018-12-20 14:45:32] INFO    | sqlalchemy.engine.base.Engine: PRAGMA table_info(\"replys\")\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2018-12-20 14:45:32,643 INFO sqlalchemy.engine.base.Engine ()\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[2018-12-20 14:45:32] INFO    | sqlalchemy.engine.base.Engine: ()\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2018-12-20 14:45:32,649 INFO sqlalchemy.engine.base.Engine PRAGMA table_info(\"upusers\")\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[2018-12-20 14:45:32] INFO    | sqlalchemy.engine.base.Engine: PRAGMA table_info(\"upusers\")\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2018-12-20 14:45:32,653 INFO sqlalchemy.engine.base.Engine ()\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[2018-12-20 14:45:32] INFO    | sqlalchemy.engine.base.Engine: ()\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2018-12-20 14:45:32,661 INFO sqlalchemy.engine.base.Engine PRAGMA table_info(\"shangusers\")\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[2018-12-20 14:45:32] INFO    | sqlalchemy.engine.base.Engine: PRAGMA table_info(\"shangusers\")\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2018-12-20 14:45:32,664 INFO sqlalchemy.engine.base.Engine ()\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[2018-12-20 14:45:32] INFO    | sqlalchemy.engine.base.Engine: ()\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2018-12-20 14:45:32,671 INFO sqlalchemy.engine.base.Engine PRAGMA table_info(\"rawcontents\")\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[2018-12-20 14:45:32] INFO    | sqlalchemy.engine.base.Engine: PRAGMA table_info(\"rawcontents\")\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2018-12-20 14:45:32,675 INFO sqlalchemy.engine.base.Engine ()\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[2018-12-20 14:45:32] INFO    | sqlalchemy.engine.base.Engine: ()\n"
     ]
    }
   ],
   "source": [
    "from utils.log import getLogger\n",
    "import pandas as pd\n",
    "from model.db import DB_ENGINE\n",
    "from sklearn.linear_model import SGDClassifier\n",
    "from sklearn.model_selection import cross_val_score\n",
    "import numpy as np\n",
    "import pickle\n",
    "import os\n",
    "\n",
    "logger = getLogger('predict')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2018-12-20 15:28:36,290 INFO sqlalchemy.engine.base.Engine PRAGMA table_info(\"SELECT rid, tag, vector FROM rawcontents WHERE assure>0.5\")\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[2018-12-20 15:28:36] INFO    | sqlalchemy.engine.base.Engine: PRAGMA table_info(\"SELECT rid, tag, vector FROM rawcontents WHERE assure>0.5\")\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2018-12-20 15:28:36,321 INFO sqlalchemy.engine.base.Engine ()\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[2018-12-20 15:28:36] INFO    | sqlalchemy.engine.base.Engine: ()\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2018-12-20 15:28:36,328 INFO sqlalchemy.engine.base.Engine SELECT rid, tag, vector FROM rawcontents WHERE assure>0.5\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[2018-12-20 15:28:36] INFO    | sqlalchemy.engine.base.Engine: SELECT rid, tag, vector FROM rawcontents WHERE assure>0.5\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2018-12-20 15:28:36,332 INFO sqlalchemy.engine.base.Engine ()\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[2018-12-20 15:28:36] INFO    | sqlalchemy.engine.base.Engine: ()\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>rid</th>\n",
       "      <th>tag</th>\n",
       "      <th>vector</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1</td>\n",
       "      <td>1.0</td>\n",
       "      <td>b'\\x80\\x03cnumpy.core.multiarray\\n_reconstruct...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>2</td>\n",
       "      <td>1.0</td>\n",
       "      <td>b'\\x80\\x03cnumpy.core.multiarray\\n_reconstruct...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>3</td>\n",
       "      <td>0.0</td>\n",
       "      <td>b'\\x80\\x03cnumpy.core.multiarray\\n_reconstruct...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>4</td>\n",
       "      <td>1.0</td>\n",
       "      <td>b'\\x80\\x03cnumpy.core.multiarray\\n_reconstruct...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>5</td>\n",
       "      <td>1.0</td>\n",
       "      <td>b'\\x80\\x03cnumpy.core.multiarray\\n_reconstruct...</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   rid  tag                                             vector\n",
       "0    1  1.0  b'\\x80\\x03cnumpy.core.multiarray\\n_reconstruct...\n",
       "1    2  1.0  b'\\x80\\x03cnumpy.core.multiarray\\n_reconstruct...\n",
       "2    3  0.0  b'\\x80\\x03cnumpy.core.multiarray\\n_reconstruct...\n",
       "3    4  1.0  b'\\x80\\x03cnumpy.core.multiarray\\n_reconstruct...\n",
       "4    5  1.0  b'\\x80\\x03cnumpy.core.multiarray\\n_reconstruct..."
      ]
     },
     "execution_count": 48,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Load the labelled rows (assure > 0.5): row id, tag and the pickled feature vector.\n",
    "# read_sql_query is used instead of read_sql so pandas does not first probe the\n",
    "# database for a table named after the whole SQL string (the PRAGMA table_info\n",
    "# call visible in the engine log).\n",
    "labeled = pd.read_sql_query(\n",
    "    'SELECT rid, tag, vector FROM rawcontents WHERE assure>0.5',\n",
    "    DB_ENGINE\n",
    ")\n",
    "labeled.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>rid</th>\n",
       "      <th>tag</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>count</th>\n",
       "      <td>1.018510e+05</td>\n",
       "      <td>101851.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>mean</th>\n",
       "      <td>5.872414e+05</td>\n",
       "      <td>0.972538</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>std</th>\n",
       "      <td>4.706109e+05</td>\n",
       "      <td>0.163425</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>min</th>\n",
       "      <td>1.000000e+00</td>\n",
       "      <td>0.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>25%</th>\n",
       "      <td>1.803830e+05</td>\n",
       "      <td>1.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>50%</th>\n",
       "      <td>4.275950e+05</td>\n",
       "      <td>1.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>75%</th>\n",
       "      <td>1.011050e+06</td>\n",
       "      <td>1.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>max</th>\n",
       "      <td>1.587761e+06</td>\n",
       "      <td>1.000000</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                rid            tag\n",
       "count  1.018510e+05  101851.000000\n",
       "mean   5.872414e+05       0.972538\n",
       "std    4.706109e+05       0.163425\n",
       "min    1.000000e+00       0.000000\n",
       "25%    1.803830e+05       1.000000\n",
       "50%    4.275950e+05       1.000000\n",
       "75%    1.011050e+06       1.000000\n",
       "max    1.587761e+06       1.000000"
      ]
     },
     "execution_count": 49,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Summary statistics of the numeric columns of the labelled rows.\n",
    "labeled.describe()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Collapse rows that share an identical pickled vector into one sample whose\n",
    "# label is the mean of their tags.\n",
    "# NOTE(review): if duplicates ever carry conflicting tags the mean is fractional\n",
    "# rather than a clean class label - confirm that is acceptable for the classifier\n",
    "# trained on these targets.\n",
    "train = labeled.groupby('vector')['tag'].mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Features: unpickle every distinct vector blob (the index of the grouped series).\n",
    "# NOTE(review): pickle.loads executes arbitrary code on untrusted input - this is\n",
    "# only safe while the stored blobs come from trusted code.\n",
    "X = [pickle.loads(blob) for blob in train.index]\n",
    "# Targets: the per-vector mean tag, in the same order as X.\n",
    "y = list(train.values)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 60,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cache the training features and targets on disk as two pickle files.\n",
    "# NOTE(review): DATA_ROOT is not defined or imported in any visible cell of this\n",
    "# notebook; presumably a path object from project config - confirm and import it\n",
    "# in the first cell so the notebook survives Restart & Run All.\n",
    "with open(DATA_ROOT/'train_data_X.bin', 'wb') as f:\n",
    "    pickle.dump(X, f)\n",
    "\n",
    "with open(DATA_ROOT/'train_data_y.bin', 'wb') as f:\n",
    "    pickle.dump(y, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload the cached training features and targets from their pickle files.\n",
    "# NOTE(review): pickle.load runs arbitrary code contained in the file - only load\n",
    "# caches produced by this project. DATA_ROOT is not defined in any visible cell;\n",
    "# confirm where it comes from.\n",
    "with open(DATA_ROOT/'train_data_X.bin', 'rb') as f:\n",
    "    X = pickle.load(f)\n",
    "\n",
    "with open(DATA_ROOT/'train_data_y.bin', 'rb') as f:\n",
    "    y = pickle.load(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn import ensemble, svm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Disabled alternative (linear SGD classifier):\n",
    "# clf = SGDClassifier(\n",
    "#     random_state=ord(os.urandom(1)),\n",
    "#     max_iter=512,\n",
    "#     tol=1e-3,\n",
    "#     penalty='elasticnet',\n",
    "#     loss='modified_huber',\n",
    "#     fit_intercept=False\n",
    "# )\n",
    "# RBF-kernel SVM. probability=True is required because clf.predict_proba is\n",
    "# called later to derive per-sample entropies; with the default\n",
    "# probability=False that call raises AttributeError. Enabling it makes fit()\n",
    "# slower, since scikit-learn calibrates the probabilities with an internal\n",
    "# cross-validation.\n",
    "clf = svm.SVC(kernel='rbf', gamma='scale', probability=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy: 0.99 (+/- 0.00)\n"
     ]
    }
   ],
   "source": [
    "# 3-fold cross-validation of the classifier on the training data.\n",
    "scores = cross_val_score(clf, X, y, cv=3)\n",
    "mean_score = scores.mean()\n",
    "spread = scores.std() * 2\n",
    "# Report the mean score with a +/- two-standard-deviation band.\n",
    "print(\"Accuracy: %0.2f (+/- %0.2f)\" % (mean_score, spread))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0,\n",
       "  decision_function_shape='ovr', degree=3, gamma='scale', kernel='rbf',\n",
       "  max_iter=-1, probability=False, random_state=None, shrinking=True,\n",
       "  tol=0.001, verbose=False)"
      ]
     },
     "execution_count": 53,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Fit the classifier on the full training set; the value returned by fit() is\n",
    "# shown as the cell output.\n",
    "clf.fit(X, y)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## predict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "from scipy import stats"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load every row to score. content is selected as well because the result\n",
    "# tables built further down read test['content']; without that column those\n",
    "# cells raise KeyError. read_sql_query avoids pandas probing the database for a\n",
    "# table named after the whole SQL string.\n",
    "test = pd.read_sql_query('SELECT rid, content, vector FROM rawcontents', DB_ENGINE)\n",
    "test.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Unpickle each stored vector blob, preserving the row order of test.\n",
    "# NOTE(review): pickle.loads executes arbitrary code on untrusted input.\n",
    "X_test = [pickle.loads(blob) for blob in test['vector']]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Predicted tag for every entry of X_test, in the same order.\n",
    "predict_tag = clf.predict(X_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 60,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.08984375"
      ]
     },
     "execution_count": 60,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Mean of the predicted tags (for 0/1 tags this is the share predicted as 1).\n",
    "predict_tag.mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 61,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>rid</th>\n",
       "      <th>content</th>\n",
       "      <th>tag</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>86</td>\n",
       "      <td>QC商城就是运用了区块链技术融合线上商城线下门店的社交新零售平台\\n \\n 他是一个将虚拟经...</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>87</td>\n",
       "      <td>《我不是药神》上映以来好评如潮，票房大涨，是一部叫好又叫座的佳作，不仅豆瓣评分9.0，更被人...</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>100</td>\n",
       "      <td>2018/7/19:这周我很少给单子。因为自己事情多，今天弄下操作计划，先讲讲黄金，这两天我...</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>101</td>\n",
       "      <td>这好像和区块链没啥关系，是外汇方面的吧？</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>102</td>\n",
       "      <td>@真心交人诚心办事 2018-07-19 09:03:48\\n \\n 这好像和区块链没啥关系...</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   rid                                            content  tag\n",
       "0   86  QC商城就是运用了区块链技术融合线上商城线下门店的社交新零售平台\\n \\n 他是一个将虚拟经...  0.0\n",
       "1   87  《我不是药神》上映以来好评如潮，票房大涨，是一部叫好又叫座的佳作，不仅豆瓣评分9.0，更被人...  0.0\n",
       "2  100  2018/7/19:这周我很少给单子。因为自己事情多，今天弄下操作计划，先讲讲黄金，这两天我...  0.0\n",
       "3  101                               这好像和区块链没啥关系，是外汇方面的吧？  1.0\n",
       "4  102  @真心交人诚心办事 2018-07-19 09:03:48\\n \\n 这好像和区块链没啥关系...  0.0"
      ]
     },
     "execution_count": 61,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Pair each row's id and text with its predicted tag.\n",
    "result = test[['rid', 'content']].assign(tag=predict_tag)\n",
    "result.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 67,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>rid</th>\n",
       "      <th>tag</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>count</th>\n",
       "      <td>2048.000000</td>\n",
       "      <td>2048.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>mean</th>\n",
       "      <td>1171.120605</td>\n",
       "      <td>0.089844</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>std</th>\n",
       "      <td>619.894848</td>\n",
       "      <td>0.286028</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>min</th>\n",
       "      <td>86.000000</td>\n",
       "      <td>0.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>25%</th>\n",
       "      <td>631.500000</td>\n",
       "      <td>0.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>50%</th>\n",
       "      <td>1170.500000</td>\n",
       "      <td>0.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>75%</th>\n",
       "      <td>1711.250000</td>\n",
       "      <td>0.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>max</th>\n",
       "      <td>2237.000000</td>\n",
       "      <td>1.000000</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "               rid          tag\n",
       "count  2048.000000  2048.000000\n",
       "mean   1171.120605     0.089844\n",
       "std     619.894848     0.286028\n",
       "min      86.000000     0.000000\n",
       "25%     631.500000     0.000000\n",
       "50%    1170.500000     0.000000\n",
       "75%    1711.250000     0.000000\n",
       "max    2237.000000     1.000000"
      ]
     },
     "execution_count": 67,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Summary statistics of the numeric columns of the prediction table.\n",
    "result.describe()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "ename": "AttributeError",
     "evalue": "predict_proba is not available when  probability=False",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mAttributeError\u001b[0m                            Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-10-df3d5f441993>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mproba\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mclf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpredict_proba\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX_test\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      2\u001b[0m \u001b[0mpredict_tag\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mclf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpredict\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX_test\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      3\u001b[0m \u001b[0mentropies\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mstats\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdistributions\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mentropy\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mproba\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mT\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbase\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/.local/lib/python3.6/site-packages/sklearn/svm/base.py\u001b[0m in \u001b[0;36mpredict_proba\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m    607\u001b[0m         \u001b[0mdatasets\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    608\u001b[0m         \"\"\"\n\u001b[0;32m--> 609\u001b[0;31m         \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_check_proba\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    610\u001b[0m         \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_predict_proba\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    611\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/.local/lib/python3.6/site-packages/sklearn/svm/base.py\u001b[0m in \u001b[0;36m_check_proba\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m    574\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0m_check_proba\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    575\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprobability\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 576\u001b[0;31m             raise AttributeError(\"predict_proba is not available when \"\n\u001b[0m\u001b[1;32m    577\u001b[0m                                  \" probability=False\")\n\u001b[1;32m    578\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_impl\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32min\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0;34m'c_svc'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'nu_svc'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mAttributeError\u001b[0m: predict_proba is not available when  probability=False"
     ]
    }
   ],
   "source": [
    "# Class-probability estimates for every entry of X_test.\n",
    "# NOTE: SVC only provides predict_proba when it was constructed with\n",
    "# probability=True; otherwise this line raises AttributeError.\n",
    "proba = clf.predict_proba(X_test)\n",
    "# Entropy (base 2) of each row's class distribution, computed through the\n",
    "# documented public entry point scipy.stats.entropy. proba is transposed\n",
    "# because entropy reduces along the first axis.\n",
    "entropies = stats.entropy(proba.T, base=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pair each row's id and text with the prediction entropy and the predicted tag.\n",
    "result = test[['rid', 'content']].assign(entropy=entropies, tag=predict_tag)\n",
    "result.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Count the entropies falling into ten equal-width, right-closed bins between 0 and 1.\n",
    "# Series.value_counts is used because the top-level pd.value_counts helper is\n",
    "# deprecated in recent pandas releases.\n",
    "entropy_bins = pd.cut(entropies, np.linspace(0, 1, 11))\n",
    "pd.Series(entropy_bins).value_counts()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show the predictions ordered from highest to lowest entropy.\n",
    "result.sort_values('entropy', ascending=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ds",
   "language": "python",
   "name": "ds"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
